[mlir][amdgpu] Add conversion from arith.scaling_{extf,truncf} to amdgpu #146372
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu

Author: Tim Gymnich (tgymnich)

Changes

Patch is 52.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146372.diff

2 Files Affected:
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 3596b3235a631..22cd4703c6005 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -15,11 +15,17 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/LogicalResult.h"
namespace mlir {
#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
@@ -32,6 +38,7 @@ using namespace mlir::amdgpu;
namespace {
// Define commonly used chipsets versions for convenience.
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
+constexpr Chipset kGfx950 = Chipset(9, 5, 0);
struct ArithToAMDGPUConversionPass final
: impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
@@ -73,6 +80,28 @@ struct TruncfToFloat16RewritePattern final
PatternRewriter &rewriter) const override;
};
+struct ScalingExtFRewritePattern final : OpRewritePattern<arith::ScalingExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ Chipset chipset;
+ ScalingExtFRewritePattern(MLIRContext *ctx, Chipset chipset)
+ : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+
+ LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
+ PatternRewriter &rewriter) const override;
+};
+
+struct ScalingTruncFRewritePattern final : OpRewritePattern<arith::ScalingTruncFOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ Chipset chipset;
+ ScalingTruncFRewritePattern(MLIRContext *ctx, Chipset chipset)
+ : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+
+ LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
+ PatternRewriter &rewriter) const override;
+};
+
} // end namespace
static bool isSupportedF8(Type elementType, Chipset chipset) {
@@ -395,6 +424,137 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
return success();
}
+static Value getOriginalVectorValue(Value value) {
+ Value current = value;
+ while (Operation *definingOp = current.getDefiningOp()) {
+ bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
+ .Case<vector::ShapeCastOp>(
+ [&current](auto op) {
+ current = op.getSource();
+ return true;
+ })
+ .Case<vector::BroadcastOp>(
+ [&current](auto op) {
+ current = op.getSource();
+ return false;
+ })
+ .Case<vector::SplatOp>(
+ [&current](auto op) {
+ current = op.getInput();
+ return false;
+ })
+ .Default([](Operation *) {
+ return false;
+ });
+
+ if (!skipOp) {
+ break;
+ }
+ }
+ return current;
+}
+
+LogicalResult
+ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
+ PatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ constexpr const int64_t opWidth = 2;
+
+ Value in = op.getIn();
+ Value scale = op.getScale();
+ Value out = op.getOut();
+
+ Type f32 = rewriter.getF32Type();
+ Type inType = getElementTypeOrSelf(in);
+ Type scaleType = getElementTypeOrSelf(scale);
+ Type outType = getElementTypeOrSelf(out);
+ VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
+ VectorType inVecType = dyn_cast<VectorType>(in.getType());
+ VectorType outVecType = dyn_cast<VectorType>(out.getType());
+
+ if (outVecType && outVecType.isScalable())
+ return failure();
+
+ Type scaleF32Type = scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
+ if (scaleType.getIntOrFloatBitWidth() < 32)
+ scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+ else if (scaleType.getIntOrFloatBitWidth() > 32)
+ scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+
+ VectorType extScaleResultType = VectorType::get(opWidth, outType);
+
+ if (!outVecType) {
+ Value inCast = rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+ Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(loc, extScaleResultType, inCast, scale, 0);
+ scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
+ return success();
+ }
+
+ Value origScale = getOriginalVectorValue(scale);
+ Type origScaleType = origScale.getType();
+ VectorType origScaleVecType = isa<VectorType>(origScaleType) ? cast<VectorType>(origScaleType) : VectorType::get(1, origScaleType);
+
+ ArrayRef<int64_t> originalScaleShape = origScaleVecType.getShape();
+ ArrayRef<int64_t> inShape = inVecType.getShape();
+
+ SmallVector<int64_t> paddedScaleShape(originalScaleShape);
+ paddedScaleShape.insert(paddedScaleShape.end(), inShape.size() - originalScaleShape.size(),
+ 1);
+
+ auto ratio = computeShapeRatio(inShape, paddedScaleShape);
+ if (!ratio)
+ return failure();
+
+ const int64_t blockSize = computeProduct(*ratio);
+
+ Value zero = rewriter.create<arith::ConstantOp>(
+ loc, outType, rewriter.getFloatAttr(outType, 0.0));
+ Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+
+ for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, *ratio)) {
+ SmallVector<int64_t> strides(offsets.size(), 1);
+ Value block = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, in, offsets, *ratio, strides);
+ VectorType block1DType = VectorType::get(blockSize, inType);
+ Value block1D =
+ rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+ Value uniformScale =
+ rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+
+ VectorType blockResultType = VectorType::get(blockSize, outType);
+ Value blockResult =
+ rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+
+ for (int64_t i = 0, sliceWidth = opWidth - blockSize % opWidth;
+ i < blockSize;
+ i += sliceWidth, sliceWidth = opWidth - blockSize % opWidth) {
+ Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, block1D, i, sliceWidth, 1);
+ Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
+ loc, extScaleResultType, slice, uniformScale, 0);
+ if (sliceWidth != opWidth)
+ scaleExt = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, scaleExt, 0, sliceWidth, 1);
+ blockResult = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, scaleExt, blockResult, i, 1);
+ }
+
+ VectorType resultType = VectorType::get(*ratio, outType);
+ Value cast = rewriter.create<vector::ShapeCastOp>(loc, resultType,
+ blockResult);
+ result = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, cast, result, offsets, strides);
+ }
+
+ rewriter.replaceOp(op, result);
+
+ return success();
+}
+
+LogicalResult ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, PatternRewriter &rewriter) const {
+ return success();
+}
+
void mlir::arith::populateArithToAMDGPUConversionPatterns(
RewritePatternSet &patterns, bool convertFP8Arithmetic,
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
@@ -406,6 +566,11 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
}
if (allowPackedF16Rtz)
patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
+
+ if (chipset >= kGfx950) {
+ patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), chipset);
+ patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), chipset);
+ }
}
void ArithToAMDGPUConversionPass::runOnOperation() {
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scale_ext.mlir b/mlir/test/Conversion/ArithToAMDGPU/scale_ext.mlir
new file mode 100644
index 0000000000000..1669926ae48cc
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/scale_ext.mlir
@@ -0,0 +1,553 @@
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s
+
+// CHECK-LABEL: @conversion_f8_fallback
+// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK-NEXT: [[SCALE_EXT:%.+]] = arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+// CHECK-NEXT: [[IN_SLICE_00:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT: [[IN_SCALAR_00:%.+]] = vector.shape_cast [[IN_SLICE_00]]
+// CHECK-NEXT: [[SCALE_SCALAR_00:%.+]] = vector.extract [[SCALE_EXT]][0, 0]
+// CHECK-NEXT: [[PACKED_00:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_00]][0], [[SCALE_SCALAR_00]]
+// CHECK-NEXT: [[OUT_SLICE_00:%.+]] = vector.extract_strided_slice [[PACKED_00]]
+// CHECK-NEXT: [[OUT_SCALAR_00:%.+]] = vector.shape_cast [[OUT_SLICE_00]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_00]], [[CST]]
+// CHECK-NEXT: [[IN_SLICE_01:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT: [[IN_SCALAR_01:%.+]] = vector.shape_cast [[IN_SLICE_01]]
+// CHECK-NEXT: [[SCALE_SCALAR_01:%.+]] = vector.extract [[SCALE_EXT]][0, 1]
+// CHECK-NEXT: [[PACKED_01:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_01]][0], [[SCALE_SCALAR_01]]
+// CHECK-NEXT: [[OUT_SLICE_01:%.+]] = vector.extract_strided_slice [[PACKED_01]]
+// CHECK-NEXT: [[OUT_SCALAR_01:%.+]] = vector.shape_cast [[OUT_SLICE_01]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_01]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_10:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT: [[IN_SCALAR_10:%.+]] = vector.shape_cast [[IN_SLICE_10]]
+// CHECK-NEXT: [[SCALE_SCALAR_10:%.+]] = vector.extract [[SCALE_EXT]][1, 0]
+// CHECK-NEXT: [[PACKED_10:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_10]][0], [[SCALE_SCALAR_10]]
+// CHECK-NEXT: [[OUT_SLICE_10:%.+]] = vector.extract_strided_slice [[PACKED_10]]
+// CHECK-NEXT: [[OUT_SCALAR_10:%.+]] = vector.shape_cast [[OUT_SLICE_10]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_10]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_11:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT: [[IN_SCALAR_11:%.+]] = vector.shape_cast [[IN_SLICE_11]]
+// CHECK-NEXT: [[SCALE_SCALAR_11:%.+]] = vector.extract [[SCALE_EXT]][1, 1]
+// CHECK-NEXT: [[PACKED_11:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_11]][0], [[SCALE_SCALAR_11]]
+// CHECK-NEXT: [[OUT_SLICE_11:%.+]] = vector.extract_strided_slice [[PACKED_11]]
+// CHECK-NEXT: [[OUT_SCALAR_11:%.+]] = vector.shape_cast [[OUT_SLICE_11]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_11]], [[ACC_A]]
+// CHECK-NEXT: return [[ACC_B]] : vector<2x2xf32>
+func.func @conversion_f8_fallback(%in: vector<2x2xf8E5M2>, %scale: vector<2x2xf8E8M0FNU>) -> vector<2x2xf32> {
+ %ext = arith.scaling_extf %in, %scale : vector<2x2xf8E5M2>, vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+ return %ext : vector<2x2xf32>
+}
+
+// CHECK-LABEL: @conversion_f4_fallback
+// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK-NEXT: [[SCALE_EXT:%.+]] = arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+// CHECK-NEXT: [[IN_SLICE_00:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT: [[IN_SCALAR_00:%.+]] = vector.shape_cast [[IN_SLICE_00]]
+// CHECK-NEXT: [[SCALE_SCALAR_00:%.+]] = vector.extract [[SCALE_EXT]][0, 0]
+// CHECK-NEXT: [[PACKED_00:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_00]][0], [[SCALE_SCALAR_00]]
+// CHECK-NEXT: [[OUT_SLICE_00:%.+]] = vector.extract_strided_slice [[PACKED_00]]
+// CHECK-NEXT: [[OUT_SCALAR_00:%.+]] = vector.shape_cast [[OUT_SLICE_00]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_00]], [[CST]]
+// CHECK-NEXT: [[IN_SLICE_01:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT: [[IN_SCALAR_01:%.+]] = vector.shape_cast [[IN_SLICE_01]]
+// CHECK-NEXT: [[SCALE_SCALAR_01:%.+]] = vector.extract [[SCALE_EXT]][0, 1]
+// CHECK-NEXT: [[PACKED_01:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_01]][0], [[SCALE_SCALAR_01]]
+// CHECK-NEXT: [[OUT_SLICE_01:%.+]] = vector.extract_strided_slice [[PACKED_01]]
+// CHECK-NEXT: [[OUT_SCALAR_01:%.+]] = vector.shape_cast [[OUT_SLICE_01]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_01]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_10:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT: [[IN_SCALAR_10:%.+]] = vector.shape_cast [[IN_SLICE_10]]
+// CHECK-NEXT: [[SCALE_SCALAR_10:%.+]] = vector.extract [[SCALE_EXT]][1, 0]
+// CHECK-NEXT: [[PACKED_10:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_10]][0], [[SCALE_SCALAR_10]]
+// CHECK-NEXT: [[OUT_SLICE_10:%.+]] = vector.extract_strided_slice [[PACKED_10]]
+// CHECK-NEXT: [[OUT_SCALAR_10:%.+]] = vector.shape_cast [[OUT_SLICE_10]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_10]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_11:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT: [[IN_SCALAR_11:%.+]] = vector.shape_cast [[IN_SLICE_11]]
+// CHECK-NEXT: [[SCALE_SCALAR_11:%.+]] = vector.extract [[SCALE_EXT]][1, 1]
+// CHECK-NEXT: [[PACKED_11:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_11]][0], [[SCALE_SCALAR_11]]
+// CHECK-NEXT: [[OUT_SLICE_11:%.+]] = vector.extract_strided_slice [[PACKED_11]]
+// CHECK-NEXT: [[OUT_SCALAR_11:%.+]] = vector.shape_cast [[OUT_SLICE_11]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_11]], [[ACC_A]]
+// CHECK-NEXT: return [[ACC_B]] : vector<2x2xf32>
+func.func @conversion_f4_fallback(%in: vector<2x2xf4E2M1FN>, %scale: vector<2x2xf8E8M0FNU>) -> vector<2x2xf32> {
+ %ext = arith.scaling_extf %in, %scale : vector<2x2xf4E2M1FN>, vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+ return %ext : vector<2x2xf32>
+}
+
+
+// CHECK-LABEL: @conversion_broadcast
+// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<8x2x4xf32>
+// CHECK-NEXT: [[BCAST:%.+]] = vector.broadcast %arg1 : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU>
+// CHECK-NEXT: [[IN_CAST:%.+]] = vector.shape_cast %arg0 : vector<8x8xf8E5M2> to vector<8x2x4xf8E5M2>
+// CHECK-NEXT: [[SCALE_CAST:%.+]] = vector.shape_cast [[BCAST]] : vector<4x8x2xf8E8M0FNU> to vector<8x2x4xf8E8M0FNU>
+// CHECK-NEXT: [[SCALE_EXT:%.+]] = arith.extf [[SCALE_CAST]] : vector<8x2x4xf8E8M0FNU> to vector<8x2x4xf32>
+// CHECK-NEXT: [[IN_SLICE_0:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_0:%.+]] = vector.shape_cast [[IN_SLICE_0]]
+// CHECK-NEXT: [[SCALE_SCALAR_0:%.+]] = vector.extract [[SCALE_EXT]][0, 0, 0]
+// CHECK-NEXT: [[PACKED_0:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_0]][0], [[SCALE_SCALAR_0]]
+// CHECK-NEXT: [[OUT_SLICE_0:%.+]] = vector.extract_strided_slice [[PACKED_0]]
+// CHECK-NEXT: [[OUT_SCALAR_0:%.+]] = vector.shape_cast [[OUT_SLICE_0]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_0]], [[CST]]
+// CHECK-NEXT: [[IN_SLICE_1:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_1:%.+]] = vector.shape_cast [[IN_SLICE_1]]
+// CHECK-NEXT: [[SCALE_SCALAR_1:%.+]] = vector.extract [[SCALE_EXT]][0, 0, 1]
+// CHECK-NEXT: [[PACKED_1:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_1]][0], [[SCALE_SCALAR_1]]
+// CHECK-NEXT: [[OUT_SLICE_1:%.+]] = vector.extract_strided_slice [[PACKED_1]]
+// CHECK-NEXT: [[OUT_SCALAR_1:%.+]] = vector.shape_cast [[OUT_SLICE_1]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_1]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_2:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_2:%.+]] = vector.shape_cast [[IN_SLICE_2]]
+// CHECK-NEXT: [[SCALE_SCALAR_2:%.+]] = vector.extract [[SCALE_EXT]][0, 0, 2]
+// CHECK-NEXT: [[PACKED_2:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_2]][0], [[SCALE_SCALAR_2]]
+// CHECK-NEXT: [[OUT_SLICE_2:%.+]] = vector.extract_strided_slice [[PACKED_2]]
+// CHECK-NEXT: [[OUT_SCALAR_2:%.+]] = vector.shape_cast [[OUT_SLICE_2]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_2]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_3:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_3:%.+]] = vector.shape_cast [[IN_SLICE_3]]
+// CHECK-NEXT: [[SCALE_SCALAR_3:%.+]] = vector.extract [[SCALE_EXT]][0, 0, 3]
+// CHECK-NEXT: [[PACKED_3:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_3]][0], [[SCALE_SCALAR_3]]
+// CHECK-NEXT: [[OUT_SLICE_3:%.+]] = vector.extract_strided_slice [[PACKED_3]]
+// CHECK-NEXT: [[OUT_SCALAR_3:%.+]] = vector.shape_cast [[OUT_SLICE_3]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_3]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_4:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_4:%.+]] = vector.shape_cast [[IN_SLICE_4]]
+// CHECK-NEXT: [[SCALE_SCALAR_4:%.+]] = vector.extract [[SCALE_EXT]][0, 1, 0]
+// CHECK-NEXT: [[PACKED_4:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_4]][0], [[SCALE_SCALAR_4]]
+// CHECK-NEXT: [[OUT_SLICE_4:%.+]] = vector.extract_strided_slice [[PACKED_4]]
+// CHECK-NEXT: [[OUT_SCALAR_4:%.+]] = vector.shape_cast [[OUT_SLICE_4]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_4]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_5:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_5:%.+]] = vector.shape_cast [[IN_SLICE_5]]
+// CHECK-NEXT: [[SCALE_SCALAR_5:%.+]] = vector.extract [[SCALE_EXT]][0, 1, 1]
+// CHECK-NEXT: [[PACKED_5:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_5]][0], [[SCALE_SCALAR_5]]
+// CHECK-NEXT: [[OUT_SLICE_5:%.+]] = vector.extract_strided_slice [[PACKED_5]]
+// CHECK-NEXT: [[OUT_SCALAR_5:%.+]] = vector.shape_cast [[OUT_SLICE_5]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_5]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_6:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_6:%.+]] = vector.shape_cast [[IN_SLICE_6]]
+// CHECK-NEXT: [[SCALE_SCALAR_6:%.+]] = vector.extract [[SCALE_EXT]][0, 1, 2]
+// CHECK-NEXT: [[PACKED_6:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_6]][0], [[SCALE_SCALAR_6]]
+// CHECK-NEXT: [[OUT_SLICE_6:%.+]] = vector.extract_strided_slice [[PACKED_6]]
+// CHECK-NEXT: [[OUT_SCALAR_6:%.+]] = vector.shape_cast [[OUT_SLICE_6]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_6]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_7:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_7:%.+]] = vector.shape_cast [[IN_SLICE_7]]
+// CHECK-NEXT: [[SCALE_SCALAR_7:%.+]] = vector.extract [[SCALE_EXT]][0, 1, 3]
+// CHECK-NEXT: [[PACKED_7:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_7]][0], [[SCALE_SCALAR_7]]
+// CHECK-NEXT: [[OUT_SLICE_7:%.+]] = vector.extract_strided_slice [[PACKED_7]]
+// CHECK-NEXT: [[OUT_SCALAR_7:%.+]] = vector.shape_cast [[OUT_SLICE_7]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_7]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_8:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_8:%.+]] = vector.shape_cast [[IN_SLICE_8]]
+// CHECK-NEXT: [[SCALE_SCALAR_8:%.+]] = vector.extract [[SCALE_EXT]][1, 0, 0]
+// CHECK-NEXT: [[PACKED_8:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_8]][0], [[SCALE_SCALAR_8]]
+// CHECK-NEXT: [[OUT_SLICE_8:%.+]] = vector.extract_strided_slice [[PACKED_8]]
+// CHECK-NEXT: [[OUT_SCALAR_8:%.+]] = vector.shape_cast [[OUT_SLICE_8]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_8]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_9:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_9:%.+]] = vector.shape_cast [[IN_SLICE_9]]
+// CHECK-NEXT: [[SCALE_SCALAR_9:%.+]] = vector.extract [[SCALE_EXT]][1, 0, 1]
+// CHECK-NEXT: [[PACKED_9:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_9]][0], [[SCALE_SCALAR_9]]
+// CHECK-NEXT: [[OUT_SLICE_9:%.+]]...
[truncated]
Suggested change:

-  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, *ratio)) {
+  for (const SmallVector<int64_t> &offsets : StaticTileOffsetRange(inShape, *ratio)) {

let's not copy vectors.
this vector is created in-place on the stack.
We'd still copy the elements. There may be copy elision but I didn't check if it is guaranteed here.
The vector is moved or copied either way.
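For context, a minimal standalone sketch (not from this PR) of the pattern under discussion: a range whose iterator builds each element on dereference and returns it by value, the way StaticTileOffsetRange materializes its offset vectors. Under C++17's guaranteed copy elision, initializing the by-value loop variable from that prvalue constructs it in place, so neither spelling below copies the element. The toy `OffsetRange` type is invented purely for illustration.

```cpp
#include <cstdio>
#include <vector>

// Toy range whose iterator returns each element by value
// (a hypothetical stand-in for StaticTileOffsetRange).
struct OffsetRange {
  struct Iterator {
    int i = 0;
    // Build the element on the fly and return it by value.
    std::vector<int> operator*() const { return {i, i + 1}; }
    Iterator &operator++() { ++i; return *this; }
    bool operator!=(const Iterator &other) const { return i != other.i; }
  };
  Iterator begin() const { return {0}; }
  Iterator end() const { return {3}; }
};

int main() {
  // By value: the prvalue from operator* is elided directly into `offsets`
  // (guaranteed since C++17), so no extra copy is made.
  for (std::vector<int> offsets : OffsetRange())
    std::printf("%d %d\n", offsets[0], offsets[1]);
  // By const reference: binds to the same temporary; also no extra copy.
  for (const std::vector<int> &offsets : OffsetRange())
    std::printf("%d %d\n", offsets[0], offsets[1]);
  return 0;
}
```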
Please find a better way to check the broadcast conversion than 400 lines of CHECK-NEXT; this is not very maintainable :)
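One possible direction, sketched against the @conversion_f8_fallback test above (which emits four packed extensions for a 2x2 input): count the amdgpu.scaled_ext_packed ops with CHECK-COUNT instead of pinning every intermediate value. This is only an illustration of the idea, not the checks the PR ended up with.

```mlir
// CHECK-LABEL: @conversion_f8_fallback
// CHECK: arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32>
// CHECK-COUNT-4: amdgpu.scaled_ext_packed
// CHECK-NOT: amdgpu.scaled_ext_packed
// CHECK: return %{{.+}} : vector<2x2xf32>
```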
✅ With the latest revision this PR passed the C/C++ code formatter.
Fix formatting, otherwise LGTM. Thanks for working on this.
amdgpu.scaled_{ext,trunc}_packed
in the future.

cc @umangyadav

fixes iree-org/iree#20821
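For reference, a small hand-written sketch of the scalar path of this lowering (types chosen for illustration; the actual output is defined by the pattern in the diff above, not by this snippet):

```mlir
// A scalar arith.scaling_extf, eligible for the non-vector branch of
// ScalingExtFRewritePattern.
func.func @scalar_ext(%v: f8E5M2, %scale: f8E8M0FNU) -> f32 {
  %r = arith.scaling_extf %v, %scale : f8E5M2, f8E8M0FNU to f32
  return %r : f32
}
// Running -convert-arith-to-amdgpu="chipset=gfx950" should, per that branch,
// roughly: extend %scale to f32, splat %v into a vector<1xf8E5M2>, emit one
// amdgpu.scaled_ext_packed producing a vector<2xf32>, and extract element 0
// as the scalar result.
```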